
import torch

def energy_score(dataset_in,
                 dataset_out,
                 net,
                 device,
                 temp=1.0):
    """Compute energy-based OoD scores for an in-dist and an OoD test set.

    Every batch of ``dataset_in.test_loader`` and then of
    ``dataset_out.test_loader`` is passed through ``net``; for each sample
    the returned score is ``-temp * logsumexp(net(x) / temp, dim=1)``.

    Args:
        dataset_in: Object exposing a ``test_loader`` attribute that yields
            ``(data, labels)`` batches of in-distribution samples.
        dataset_out: Same as ``dataset_in`` but for out-of-distribution
            samples.
        net: Callable mapping a batch of inputs to a tensor whose dimension
            1 is reduced by ``logsumexp`` (presumably class logits of shape
            ``(batch, classes)`` -- verify against the model in use).
        device: Device the input batches are moved to before calling
            ``net``.
        temp: Positive temperature used to scale the logits. Defaults to 1.

    Returns:
        A ``(labels, pred)`` tuple of 1-D NumPy arrays of equal length.
        ``labels`` is 0 for in-distribution samples and 1 for OoD samples;
        ``pred`` holds the negated, temperature-scaled logsumexp values in
        the same order (all in-distribution samples first).

    Raises:
        ValueError: If ``temp`` is not strictly positive.

    Note:
        ``net`` is called as-is under ``torch.no_grad()``; this function
        does not switch it to evaluation mode.
    """
    if temp <= 0:
        raise ValueError("temp must be positive, got {}".format(temp))

    loaders = [dataset_in.test_loader, dataset_out.test_loader]

    # Collect (dataset_index, batch_size, scores) per batch. The output size
    # is derived from what the loaders actually yield rather than from
    # len(loader.dataset): with drop_last=True or a sub-sampling sampler the
    # two differ, and a pre-sized buffer would keep zero-filled entries that
    # look like extra in-distribution samples.
    batches = []
    with torch.no_grad():
        for dataset_index, loader in enumerate(loaders):
            # The targets delivered by the loader are not needed: the label
            # returned by this function is the in/out indicator only.
            for data, _ in loader:
                data = data.to(device)
                out = net(data)
                scores = temp * torch.logsumexp(out / temp, dim=1)
                batches.append((dataset_index, data.shape[0], scores))

    total = sum(batch_size for _, batch_size, _ in batches)
    pred = torch.zeros(total).to(device)
    y = torch.zeros_like(pred)

    offset = 0
    for dataset_index, batch_size, scores in batches:
        end = offset + batch_size
        pred[offset:end] = scores
        # dataset_index = 0 for In-Dist and dataset_index = 1 for OoD
        y[offset:end] = dataset_index
        offset = end

    labels = y.cpu().numpy()
    # Negate so the returned value is the energy: higher means more OoD.
    pred = -pred.cpu().numpy()
    return labels, pred
